import torch
import torch_npu
import triton


# Column labels for a row of three microsecond timings.
# NOTE(review): presumably these label, in order, the three values returned by
# triton_profiler_wrapper below -- confirm against the code that consumes them.
times_list_header = [
    "time_us",
    "time_us_min",
    "time_us_max",
]

def triton_profiler_wrapper(call):
    """Time ``call`` with ``triton.testing.do_bench`` and report microseconds.

    Args:
        call: The callable forwarded unchanged to ``do_bench`` as the
            function to benchmark.

    Returns:
        A list of three values: the three results produced by ``do_bench``
        (in the order it returns them), each converted with ``ms2us``.

    NOTE(review): ``do_bench`` is invoked with ``quantiles=[0.5, 0.2, 0.8]``
    *and* ``return_mode="mean"``. According to Triton's documentation, when
    ``quantiles`` is supplied the quantile values are returned and
    ``return_mode`` is not consulted -- so the three results are presumably
    the median, 20th and 80th percentiles rather than a true mean/min/max.
    Confirm against the installed Triton version before relying on the
    labels used by callers.
    """
    requested_quantiles = [0.5, 0.2, 0.8]
    bench_result = triton.testing.do_bench(
        call,
        warmup=5,
        rep=100,
        grad_to_none=None,
        quantiles=requested_quantiles,
        return_mode="mean",
    )
    # Unpacking requires exactly three results, one per requested quantile.
    q50_ms, q20_ms, q80_ms = bench_result
    return [ms2us(elapsed_ms) for elapsed_ms in (q50_ms, q20_ms, q80_ms)]

def ms2us(ms):
    """Convert a duration from milliseconds to microseconds.

    Args:
        ms: Duration in milliseconds; any value supporting ``* 1000.0``.

    Returns:
        ``ms`` multiplied by ``1000.0``.
    """
    us_per_ms = 1000.0
    return ms * us_per_ms